{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stochastic Variational GP Regression with Contour Integral Quadrature\n",
    "\n",
    "\n",
    "## Overview\n",
    "\n",
    "This notebook demonstrates how to perform stochastic variational GP regression using **contour integral quadrature (CIQ) with msMINRES** as described in [Pleiss et al., 2020](https://arxiv.org/pdf/2006.11267.pdf).\n",
    "Contour integral quadrature can be used in place of standard SVGP when:\n",
    "\n",
    " - There are many inducing points (e.g. M > 5000)\n",
    " - The inducing points have special structure (e.g. lie on a grid)\n",
    "\n",
    "We'll give an overview of how to use CIQ-SVGP stochastic variational regression ((https://arxiv.org/pdf/1411.2005.pdf)) to rapidly train using minibatches on the `3droad` UCI dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Make plots inline\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "smoke_test = ('CI' in os.environ)\n",
    "\n",
    "\n",
    "if not smoke_test and not os.path.isfile('../3droad.mat'):\n",
    "    print('Downloading \\'3droad\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://www.dropbox.com/s/f6ow1i59oqx05pl/3droad.mat?dl=1', '../3droad.mat')\n",
    "\n",
    "if smoke_test:  # this is for running the notebook in our testing framework\n",
    "    X, y = torch.randn(10, 2), torch.randn(10)\n",
    "else:\n",
    "    data = torch.Tensor(loadmat('../3droad.mat')['data'])\n",
    "    X = data[:, :-2]\n",
    "    X = X - X.min(0)[0]\n",
    "    X = 2 * (X / X.max(0)[0]) - 1\n",
    "    y = data[:, -1]\n",
    "    y.sub_(y.mean(0)).div_(y.std(0))\n",
    "    \n",
    "    # Let's subsample the data\n",
    "    indices = torch.randperm(X.size(0))[:10000]\n",
    "    X = X[indices]\n",
    "    y = y[indices]\n",
    "\n",
    "\n",
    "train_n = int(floor(0.8 * len(X)))\n",
    "train_x = X[:train_n, :].contiguous()\n",
    "train_y = y[:train_n].contiguous()\n",
    "\n",
    "test_x = X[train_n:, :].contiguous()\n",
    "test_y = y[train_n:].contiguous()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DataLoaders with CIQ-SVGP\n",
    "\n",
    "CIQ offers computational speedups only when the **minibatch size is much smaller** than the number of inducing points.\n",
    "We find that a minibatch size of 256 often works well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "train_dataset = TensorDataset(train_x, train_y)\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "# Smaller batch sizes are better for CIQ\n",
    "\n",
    "test_dataset = TensorDataset(test_x, test_y)\n",
    "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Number of inducing points\n",
    "\n",
    "CIQ offers computational speedups when there are lots of inducing points.\n",
    "Here, we are choosing 2000 inducing points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "inducing_points = train_x[torch.randperm(train_x.size(0))[:2000]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIQ - SVGP models\n",
    "\n",
    "To use contour integral quadrature, simply replace `VariationalStrategy` with `CiqVariationalStrategy`.\n",
    "\n",
    "In this example, we are using a `NaturalVariationalStrategy`, as CIQ works best with natural gradient descent.\n",
    "(See [the NGD tutorial](./Natural_Gradient_Descent.ipynb) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPModel(gpytorch.models.ApproximateGP):\n",
    "    def __init__(self, inducing_points):\n",
    "        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.size(0))\n",
    "        variational_strategy = gpytorch.variational.CiqVariationalStrategy(\n",
    "            self, inducing_points, variational_distribution, learn_inducing_locations=True\n",
    "        )\n",
    "        super(GPModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(\n",
    "            gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=2)\n",
    "        )\n",
    "        self.covar_module.base_kernel.initialize(lengthscale=0.01)  # Specific to the 3droad dataset\n",
    "        \n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "\n",
    "model = GPModel(inducing_points=inducing_points)\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    likelihood = likelihood.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "variational_ngd_optimizer = gpytorch.optim.NGD(model.variational_parameters(), num_data=train_y.size(0), lr=0.01)\n",
    "\n",
    "hyperparameter_optimizer = torch.optim.Adam([\n",
    "    {'params': model.hyperparameters()},\n",
    "    {'params': likelihood.parameters()},\n",
    "], lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33bc8e9c8f9a412b9c692c3deb807581",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=4.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=32.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=32.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=32.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=32.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "likelihood.train()\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))\n",
    "\n",
    "num_epochs = 1 if smoke_test else 4\n",
    "epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc=\"Epoch\")\n",
    "for i in epochs_iter:\n",
    "    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc=\"Minibatch\", leave=False)\n",
    "    \n",
    "    for x_batch, y_batch in minibatch_iter:\n",
    "        variational_ngd_optimizer.zero_grad()\n",
    "        hyperparameter_optimizer.zero_grad()\n",
    "        output = model(x_batch)\n",
    "        loss = -mll(output, y_batch)\n",
    "        minibatch_iter.set_postfix(loss=loss.item())\n",
    "        loss.backward()\n",
    "        variational_ngd_optimizer.step()\n",
    "        hyperparameter_optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test MAE: 0.6400326490402222\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "means = torch.tensor([0.])\n",
    "with torch.no_grad():\n",
    "    for x_batch, y_batch in test_loader:\n",
    "        preds = model(x_batch)\n",
    "        means = torch.cat([means, preds.mean.cpu()])\n",
    "means = means[1:]\n",
    "print('Test MAE: {}'.format(torch.mean(torch.abs(means - test_y.cpu()))))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
